import torch
a = torch.ones([1,2])
b = torch.ones([1,2])
c = torch.cat([a,b],1) # 按列拼接
print(c)
d = torch.cat([a,b],0) # 按行拼接
print(d)
